{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9eb3e7c0-c18a-4d68-8e01-62bacc584228",
   "metadata": {},
   "source": [
    "# Variational Quantum Singular Value Decomposition\n",
    "\n",
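    "Variational Quantum Singular Value Decomposition (VQSVD) uses parameterized quantum circuits to compute the singular value decomposition (SVD) of a matrix. This notebook reproduces the results of paper [1].\n",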
    "\n",
    "The VQSVD algorithm given in the paper is as follows:\n",
    "\n",
    "![](images/vqsvd_algorithm.png)\n",
    "\n",
    "Its flowchart is shown below:\n",
    "\n",
    "![](images/vqsvd_diagram.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4e8b68e-f37a-43d7-90b7-52a4cfb4ebb5",
   "metadata": {},
   "source": [
    "## Ansatz circuits\n",
    "\n",
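    "The paper gives four circuit structures, shown in the figure below. If the matrix to be decomposed is real, use $U = R_y(\\alpha_j)$; if it contains complex entries, use $R_z(\\theta_j)R_y(\\phi_j)R_z(\\psi_j)$. Repeating the structures shown in the figure yields circuits of different depths. Since most experiments use 3 qubits, all structures except (a) are implemented only for the 3-qubit case. The implementation is the function `get_ansatz`, which uses only $R_y$ rotations, i.e. it covers the real case.\n",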
    "\n",
    "![](images/ansatz_type.png)\n",
    "\n",
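    "```python\n",
    "from mindquantum.core.circuit import Circuit, add_prefix\n",
    "from mindquantum.core.gates import CNOT, RY\n",
    "\n",
    "\n",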
    "def get_ansatz(n_qubit: int, depth: int, kind='a') -> Circuit:\n",
    "    \"\"\"Get ansatz circuit.\n",
    "\n",
    "    Args:\n",
    "        n_qubit: number of qubits used.\n",
    "        depth: number of extra block repetitions (depth + 1 blocks in total).\n",
    "        kind: circuit type, one of {'a', 'b', 'c', 'd'}.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (ansatz_u, ansatz_v): the same circuit with parameter names\n",
    "        prefixed by 'alpha' and 'beta', respectively.\n",
    "    \"\"\"\n",
    "    ansatz = Circuit()\n",
    "\n",
    "    if kind == 'a':\n",
    "        for layer in range(depth+1):\n",
    "            for i in range(n_qubit):\n",
    "                ansatz += RY(f'{layer * n_qubit + i + 1}').on(i)\n",
    "            for i in range(n_qubit - 1):\n",
    "                ansatz += CNOT(i+1, i)\n",
    "    elif kind == 'b':\n",
    "        assert n_qubit == 3, \"this circuit only supports n_qubit = 3\"\n",
    "        for layer in range(depth+1):\n",
    "            ansatz += Circuit([\n",
    "                RY(f'{8 * layer + 0}').on(0),\n",
    "                RY(f'{8 * layer + 1}').on(1),\n",
    "                CNOT(1, 0),\n",
    "                RY(f'{8 * layer + 2}').on(0),\n",
    "                RY(f'{8 * layer + 3}').on(1),\n",
    "                RY(f'{8 * layer + 4}').on(1),\n",
    "                RY(f'{8 * layer + 5}').on(2),\n",
    "                CNOT(2, 1),\n",
    "                RY(f'{8 * layer + 6}').on(1),\n",
    "                RY(f'{8 * layer + 7}').on(2)\n",
    "            ])\n",
    "    elif kind == 'c':\n",
    "        assert n_qubit == 3, \"this circuit only supports n_qubit = 3\"\n",
    "        for layer in range(depth+1):\n",
    "            ansatz += Circuit([\n",
    "                RY(f'{3 * layer + 0}').on(0),\n",
    "                RY(f'{3 * layer + 1}').on(1),\n",
    "                RY(f'{3 * layer + 2}').on(2),\n",
    "                CNOT(1, 0),\n",
    "                CNOT(2, 1),\n",
    "                CNOT(0, 2)\n",
    "            ])\n",
    "    elif kind == 'd':\n",
    "        assert n_qubit == 3, \"this circuit only supports n_qubit = 3\"\n",
    "        for layer in range(depth+1):\n",
    "            ansatz += Circuit([\n",
    "                RY(f'{6 * layer + 0}').on(0),\n",
    "                RY(f'{6 * layer + 1}').on(1),\n",
    "                RY(f'{6 * layer + 2}').on(2),\n",
    "                CNOT(1, 0),\n",
    "                CNOT(2, 1),\n",
    "                # The second rotation layer uses its own parameters (3 to 5).\n",
    "                RY(f'{6 * layer + 3}').on(0),\n",
    "                RY(f'{6 * layer + 4}').on(1),\n",
    "                RY(f'{6 * layer + 5}').on(2),\n",
    "                CNOT(0, 2),\n",
    "                CNOT(1, 2)\n",
    "            ])\n",
    "    else:\n",
    "        raise ValueError(\"Parameter kind should be one of {'a', 'b', 'c', 'd'}\")\n",
    "\n",
    "    ansatz_u = add_prefix(ansatz, 'alpha')\n",
    "    ansatz_v = add_prefix(ansatz, 'beta')\n",
    "    return (ansatz_u, ansatz_v)\n",
    "```"
   ]
  },
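  {
   "cell_type": "markdown",
   "id": "5d1e7f3a-sketch-ansatz-a",
   "metadata": {},
   "source": [
    "The matrices of $R_y$ and CNOT are real, so structure (a) can only produce real orthogonal unitaries. This is why it suits real input matrices. The following minimal NumPy-only sketch (not the MindQuantum implementation) builds the unitary of structure (a) and checks this property. Several details are assumptions made for illustration: the helper names `ry`, `op_on`, `cnot` and `ansatz_a`, the big-endian qubit ordering, and the reading of `CNOT(i+1, i)` as (target, control). None of them changes the conclusion.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def ry(theta):\n",
    "    # Matrix of R_y(theta) = exp(-i * theta * Y / 2); all entries are real.\n",
    "    c, s = np.cos(theta / 2), np.sin(theta / 2)\n",
    "    return np.array([[c, -s], [s, c]])\n",
    "\n",
    "\n",
    "def op_on(n_qubit, ops):\n",
    "    # Tensor product over qubits 0..n_qubit-1 (qubit 0 is the leftmost factor);\n",
    "    # `ops` maps a qubit index to a 2x2 matrix, other qubits get the identity.\n",
    "    mat = np.eye(1)\n",
    "    for q in range(n_qubit):\n",
    "        mat = np.kron(mat, ops.get(q, np.eye(2)))\n",
    "    return mat\n",
    "\n",
    "\n",
    "def cnot(n_qubit, target, control):\n",
    "    # Permutation matrix that flips `target` whenever `control` is 1.\n",
    "    dim = 2 ** n_qubit\n",
    "    mat = np.zeros((dim, dim))\n",
    "    for idx in range(dim):\n",
    "        bits = [(idx >> (n_qubit - 1 - q)) & 1 for q in range(n_qubit)]\n",
    "        if bits[control]:\n",
    "            bits[target] ^= 1\n",
    "        out = int(''.join(str(b) for b in bits), 2)\n",
    "        mat[out, idx] = 1.0\n",
    "    return mat\n",
    "\n",
    "\n",
    "def ansatz_a(n_qubit, depth, params):\n",
    "    # depth + 1 layers: R_y on every qubit, then the CNOT ladder of get_ansatz.\n",
    "    params = np.asarray(params).reshape(depth + 1, n_qubit)\n",
    "    unitary = np.eye(2 ** n_qubit)\n",
    "    for layer in range(depth + 1):\n",
    "        rot = op_on(n_qubit, {q: ry(params[layer, q]) for q in range(n_qubit)})\n",
    "        unitary = rot @ unitary\n",
    "        for i in range(n_qubit - 1):\n",
    "            unitary = cnot(n_qubit, i + 1, i) @ unitary\n",
    "    return unitary\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "n_qubit, depth = 3, 2\n",
    "u = ansatz_a(n_qubit, depth, rng.uniform(0, 2 * np.pi, n_qubit * (depth + 1)))\n",
    "print(np.isrealobj(u), np.allclose(u.T @ u, np.eye(2 ** n_qubit)))\n",
    "```\n",
    "\n",
    "With 3 qubits and `depth = 2` the sketch uses `3 * (depth + 1) = 9` angles. This matches the number of parameters that `get_ansatz(3, 2, kind='a')` assigns to each of its two circuits."
   ]
  },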
  {
   "cell_type": "markdown",
   "id": "20ae18d1-0b41-44dd-bb96-6d91bb8f3dd8",
   "metadata": {},
   "source": [
    "## Matrix decomposition\n",
    "\n",
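    "The experiments require decomposing a non-Hermitian matrix as $M = \\sum_{k=1}^K c_k A_k$, where each $A_k$ is a Hermitian matrix. In the implementation below, each $A_k$ is a tensor product of the matrices of the I, X, Y and Z gates. Any matrix can be decomposed in the basis of tensor products of Pauli matrices. For example, a $16\\times 16$ matrix can be decomposed as\n",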
    "\n",
    "$$\n",
    "M = \\frac{1}{16}\\sum_{i,j,k,l} h_{ijkl} \\sigma_i\\otimes\\sigma_j\\otimes\\sigma_k\\otimes\\sigma_l\n",
    "$$\n",
    "\n",
    "where $i,j,k,l\\in\\{0, 1, 2, 3\\}$ and\n",
    "\n",
    "$$\n",
    "h_{ijkl} = Tr((\\sigma_i\\otimes\\sigma_j\\otimes\\sigma_k\\otimes\\sigma_l)^\\dagger \\cdot M)\n",
    "$$\n",
    "\n",
    "Here $\\sigma_0, \\sigma_1, \\sigma_2, \\sigma_3$ denote the $I, X, Y, Z$ matrices, respectively.\n",
    "\n",
    "Generalizing to an $N\\times N$ matrix with $N = 2^n$, where $n$ is the number of qubits, gives\n",
    "\n",
    "$$\n",
    "M = \\frac{1}{N} \\sum_J Tr(\\tilde{\\sigma}_J^\\dagger M)\\,\\tilde{\\sigma}_J\n",
    "$$\n",
    "\n",
    "where\n",
    "\n",
    "$$\n",
    "\\tilde{\\sigma}_J = \\bigotimes_{k=1}^n \\tilde{\\sigma}_{J_k}\n",
    "$$\n",
    "\n",
    "and $J\\in \\{0, 1, 2, 3\\}^n$ is a tuple of $n$ integers $J_k\\in \\{0, 1, 2, 3\\}$, with $\\tilde{\\sigma} = (I, X, Y, Z)$.\n",
    "\n",
    "For a given input matrix, the decomposition is implemented as follows:\n",
    "\n",
    "```python\n",
    "import itertools as it\n",
    "from typing import List\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def decompose_matrix(n_qubit: int, in_mat: np.ndarray) -> List[tuple]:\n",
    "    \"\"\"Decompose an arbitrary matrix into a linear combination of tensor\n",
    "    products of the Pauli operators {I, X, Y, Z}.\n",
    "\n",
    "    Args:\n",
    "        n_qubit: number of qubits.\n",
    "        in_mat: input matrix to be decomposed, of shape (2**n_qubit, 2**n_qubit).\n",
    "\n",
    "    Returns:\n",
    "        List[(coef, mat)]: `mat` is a tensor product of Pauli operators and\n",
    "            `coef` is the corresponding coefficient.\n",
    "    \"\"\"\n",
    "    def hs_product(m1, m2):\n",
    "        \"\"\"Hilbert-Schmidt-Product of two matrices `m1`, `m2`\"\"\"\n",
    "        return (np.dot(m1.conjugate().transpose(), m2)).trace()\n",
    "\n",
    "    def krons(arrs):\n",
    "        \"\"\"Get kron product of matrix array.\"\"\"\n",
    "        res = arrs[0]\n",
    "        for a in arrs[1:]:\n",
    "            res = np.kron(res, a)\n",
    "        return res\n",
    "\n",
    "    sx = np.array([[0, 1], [1, 0]], dtype=np.complex128)\n",
    "    sy = np.array([[0, -1j], [1j, 0]], dtype=np.complex128)\n",
    "    sz = np.array([[1, 0], [0, -1]], dtype=np.complex128)\n",
    "    ey = np.array([[1, 0], [0, 1]], dtype=np.complex128)\n",
    "    # Ordered as (I, X, Y, Z), i.e. sigma_0 ... sigma_3 in the text above.\n",
    "    op_list = np.array([ey, sx, sy, sz])\n",
    "\n",
    "    # All 4**n_qubit Pauli strings, indexed by J in {0, 1, 2, 3}^n_qubit.\n",
    "    items = list(it.product([0, 1, 2, 3], repeat=n_qubit))\n",
    "\n",
    "    result = []\n",
    "    for item in items:\n",
    "        mat = krons(op_list[list(item)])\n",
    "        coef = hs_product(mat, in_mat) / (2**n_qubit)\n",
    "        result.append((coef, mat))\n",
    "    return result\n",
    "```"
   ]
  },
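  {
   "cell_type": "markdown",
   "id": "8a2c4b61-sketch-pauli-decomposition",
   "metadata": {},
   "source": [
    "To check the decomposition formula numerically, here is a small NumPy sketch that is independent of `decompose_matrix` (the names `pauli_string` and `pauli_coefficients` are illustrative). It decomposes a random real, non-Hermitian $4\\times 4$ matrix ($n = 2$) into the $4^n = 16$ Pauli strings with coefficients $Tr(\\tilde{\\sigma}_J^\\dagger M)/N$ and rebuilds the matrix from them:\n",
    "\n",
    "```python\n",
    "import itertools as it\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# Single-qubit Pauli basis ordered as (I, X, Y, Z) = (sigma_0, ..., sigma_3).\n",
    "PAULIS = [\n",
    "    np.eye(2, dtype=complex),\n",
    "    np.array([[0, 1], [1, 0]], dtype=complex),\n",
    "    np.array([[0, -1j], [1j, 0]], dtype=complex),\n",
    "    np.array([[1, 0], [0, -1]], dtype=complex),\n",
    "]\n",
    "\n",
    "\n",
    "def pauli_string(index):\n",
    "    # sigma_{J_1} (x) ... (x) sigma_{J_n} for the tuple J = index.\n",
    "    mat = np.eye(1, dtype=complex)\n",
    "    for j in index:\n",
    "        mat = np.kron(mat, PAULIS[j])\n",
    "    return mat\n",
    "\n",
    "\n",
    "def pauli_coefficients(mat):\n",
    "    # c_J = Tr(sigma_J^dagger M) / N for every J in {0, 1, 2, 3}^n.\n",
    "    dim = mat.shape[0]\n",
    "    n_qubit = int(np.log2(dim))\n",
    "    return {index: np.trace(pauli_string(index).conj().T @ mat) / dim\n",
    "            for index in it.product(range(4), repeat=n_qubit)}\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(1)\n",
    "m = rng.normal(size=(4, 4))  # real and, almost surely, non-Hermitian\n",
    "coefs = pauli_coefficients(m)\n",
    "rebuilt = sum(c * pauli_string(index) for index, c in coefs.items())\n",
    "print(len(coefs), np.allclose(rebuilt, m))\n",
    "```\n",
    "\n",
    "The number of terms grows as $4^n$: for $n = 5$ (a $32\\times 32$ matrix) there are already 1024 Pauli strings."
   ]
  },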
  {
   "cell_type": "markdown",
   "id": "e8ba8fc3-134c-4955-b797-685c695e2e25",
   "metadata": {},
   "source": [
    "## Choosing the basis vectors\n",
    "\n",
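    "Choosing the basis is simple. For 3 qubits, generate the 8 computational basis vectors $|000\\rangle, |001\\rangle, |010\\rangle, \\cdots$. If the required rank `rank` is smaller than 8, any `rank` of them can be used. The implementation is as follows:\n",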
    "\n",
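    "```python\n",
    "import itertools as it\n",
    "from typing import List\n",
    "\n",
    "from mindquantum.core.circuit import Circuit\n",
    "from mindquantum.core.gates import X\n",
    "\n",
    "\n",
    "def get_basis(n_qubit: int, rank=5) -> List:\n",
    "    \"\"\"Get `rank` orthonormal computational basis states, each prepared by a circuit.\n",
    "\n",
    "    The bit patterns 0...00, 0...01, ... are generated in turn until `rank`\n",
    "    states are produced; an X gate is applied to every qubit whose bit is 1.\n",
    "\n",
    "    Args:\n",
    "        n_qubit: number of qubits used.\n",
    "        rank: number of basis states.\n",
    "\n",
    "    Returns:\n",
    "        List of circuits, one per basis state.\n",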
    "    \"\"\"\n",
    "    basis = []\n",
    "    for k, ps in enumerate(it.product([0, 1], repeat=n_qubit)):\n",
    "        if k >= rank:\n",
    "            break\n",
    "        cir = Circuit()\n",
    "        for i, p in enumerate(ps):\n",
    "            if p == 1:\n",
    "                cir += X.on(i)\n",
    "        basis.append(cir)\n",
    "    return basis\n",
    "```"
   ]
  },
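  {
   "cell_type": "markdown",
   "id": "c47e9d20-sketch-basis-vectors",
   "metadata": {},
   "source": [
    "The algorithm only needs these states to be orthonormal. The NumPy sketch below turns the first `rank` bit patterns into state vectors and checks that property. The patterns are enumerated in the same order as in `get_basis`. The function `basis_vectors` is hypothetical and not part of `src/`.\n",
    "\n",
    "```python\n",
    "import itertools as it\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "KET = {0: np.array([1.0, 0.0]), 1: np.array([0.0, 1.0])}\n",
    "\n",
    "\n",
    "def basis_vectors(n_qubit, rank):\n",
    "    # Same enumeration as get_basis, truncated to `rank` patterns. The leftmost\n",
    "    # bit becomes the first tensor factor: an assumed convention, since a\n",
    "    # simulator may order qubits differently.\n",
    "    vectors = []\n",
    "    for bits in it.islice(it.product([0, 1], repeat=n_qubit), rank):\n",
    "        vec = np.ones(1)\n",
    "        for b in bits:\n",
    "            vec = np.kron(vec, KET[b])\n",
    "        vectors.append(vec)\n",
    "    return np.array(vectors)\n",
    "\n",
    "\n",
    "basis = basis_vectors(3, 5)\n",
    "print(np.allclose(basis @ basis.T, np.eye(5)))  # rows are orthonormal\n",
    "print(basis.argmax(axis=1))  # position of the single 1 in each state\n",
    "```\n",
    "\n",
    "Under this convention the `rank` states are simply the standard basis vectors $e_0, \\dots, e_{rank-1}$. A different qubit ordering only permutes them, which does not affect orthonormality."
   ]
  },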
  {
   "cell_type": "markdown",
   "id": "9439100f-17b5-4080-8c2b-c28c34871025",
   "metadata": {},
   "source": [
    "## Training\n",
    "\n",
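    "Training follows the paper: compute the loss $L(\\alpha,\\beta)$ and update the parameters with an optimizer. Many optimizers require the gradient with respect to each parameter. If no interface returns the gradient of the whole loss, use linearity instead: since $M = \\sum_k c_k A_k$, the total gradient is the coefficient-weighted sum of the gradients of the individual terms. Here the parameters are updated with the Adam optimizer provided by MindSpore. Adam minimizes while the loss is to be maximized, so the gradient is negated. The key training code is as follows:\n",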
    "\n",
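    "```python\n",
    "    # Excerpt of VQSVDTrainer methods (see src/vqsvd.py). The module is assumed to\n",
    "    # import numpy as np, scipy.sparse as sparse, mindspore as ms and Hamiltonian.\n",
    "    def train_one_loop(self):\n",
    "        \"\"\"Iterate over the basis states, updating the weights after each one.\"\"\"\n",
    "        expect = 0.0\n",
    "        for i, (cir_left, cir_right) in enumerate(\n",
    "                zip(self.circuits_left, self.circuits_right)):\n",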
    "            grad = np.zeros_like(self.weight, dtype=np.complex128)\n",
    "            fval = 0.0\n",
    "\n",
    "            for (coef, mat) in self.mat_item:\n",
    "                ham = Hamiltonian(sparse.csr_matrix(mat))\n",
    "                ops = self.sim.get_expectation_with_grad(ham, cir_right, cir_left)\n",
    "                f, g = ops(self.weight)\n",
    "                fval += coef * f[0][0].conj()\n",
    "                grad += coef * g[0][0].conj()\n",
    "            # Update weight\n",
    "            weighted_grad = self.q[i] * grad.real\n",
    "            self.update_weight(weighted_grad)\n",
    "            expect += self.q[i] * fval.real\n",
    "        return expect\n",
    "\n",
    "    def update_weight(self, grad):\n",
    "        \"\"\"Update weight.\"\"\"\n",
    "        if self.method == 'adam':\n",
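    "            # Adam minimizes, so negate the gradient to maximize the loss.\n",
    "            grad = -grad\n",
    "            grad = ms.Tensor(grad, dtype=ms.float64)\n",
    "            self.optimizer((grad,))\n",
    "            self.weight = self.optimizer.parameters[0].asnumpy()\n",
    "        else:\n",
    "            # Plain gradient ascent on the loss.\n",
    "            self.weight += self.lr * grad\n",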
    "```\n",
    "\n",
    "Note: the whole network can also be trained with the optimizers available through `scipy.optimize.minimize`."
   ]
  },
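  {
   "cell_type": "markdown",
   "id": "e93b1f57-sketch-vqsvd-loss",
   "metadata": {},
   "source": [
    "The loss is easiest to see on a toy problem. The sketch below is a purely classical NumPy stand-in under these assumptions:\n",
    "\n",
    "- a real $2\\times 2$ matrix (one qubit);\n",
    "- $U = R_y(\\alpha)$ and $V = R_y(\\beta)$;\n",
    "- basis states $|\\psi_j\\rangle = e_j$;\n",
    "- central finite differences in place of `get_expectation_with_grad`.\n",
    "\n",
    "All function names are illustrative. The sketch maximizes $L(\\alpha,\\beta) = \\sum_j q_j \\mathrm{Re}\\langle\\psi_j|U^\\dagger M V|\\psi_j\\rangle$ by plain gradient ascent, mirroring the non-Adam branch of `update_weight`. It then compares the result with $\\sum_j q_j \\sigma_j$:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def ry(theta):\n",
    "    # Real 2x2 matrix of R_y(theta), used for both U(alpha) and V(beta).\n",
    "    c, s = np.cos(theta / 2), np.sin(theta / 2)\n",
    "    return np.array([[c, -s], [s, c]])\n",
    "\n",
    "\n",
    "def vqsvd_loss(params, mat, q):\n",
    "    # sum_j q_j <e_j| U^T M V |e_j>, i.e. the q-weighted diagonal of U^T M V.\n",
    "    alpha, beta = params\n",
    "    return float(np.sum(q * np.diag(ry(alpha).T @ mat @ ry(beta))))\n",
    "\n",
    "\n",
    "def grad_fd(fun, params, eps=1e-6):\n",
    "    # Central finite differences, one parameter at a time.\n",
    "    grad = np.zeros_like(params)\n",
    "    for k in range(len(params)):\n",
    "        shift = np.zeros_like(params)\n",
    "        shift[k] = eps\n",
    "        grad[k] = (fun(params + shift) - fun(params - shift)) / (2 * eps)\n",
    "    return grad\n",
    "\n",
    "\n",
    "mat = np.array([[3.0, 1.0], [1.0, 2.0]])  # det(M) > 0, see the note below\n",
    "q = np.array([2.0, 1.0])  # q_1 > q_2 > 0, as in step 2\n",
    "params = np.array([0.3, -0.7])\n",
    "lr = 0.1\n",
    "for _ in range(500):\n",
    "    # Gradient ascent, like `self.weight += self.lr * grad`.\n",
    "    params = params + lr * grad_fd(lambda p: vqsvd_loss(p, mat, q), params)\n",
    "\n",
    "sigma = np.linalg.svd(mat, compute_uv=False)\n",
    "print(vqsvd_loss(params, mat, q), q @ sigma)\n",
    "print(np.diag(ry(params[0]).T @ mat @ ry(params[1])), sigma)\n",
    "```\n",
    "\n",
    "The matrix has a positive determinant on purpose. $R_y$ rotations have determinant $+1$, so for a $2\\times 2$ matrix with negative determinant the best reachable value is $q_1\\sigma_1 - q_2\\sigma_2$ rather than $q_1\\sigma_1 + q_2\\sigma_2$."
   ]
  },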
  {
   "cell_type": "markdown",
   "id": "0cef47b1-5962-487e-b762-dd5b49807a9a",
   "metadata": {},
   "source": [
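    "The relevant interfaces are wrapped as follows:\n",
    "\n",
    "- `decompose_matrix` decomposes the input matrix into the form $M = \\sum_{k=1}^K c_k A_k$ of step 1 of the algorithm.\n",
    "- `get_basis` returns the computational basis of step 3.\n",
    "- `reconstruct_by_vqsvd` reconstructs the matrix from the given parameters.\n",
    "- `VQSVDTrainer` wraps the parameter training.\n",
    "\n",
    "The main function `run_demo` is explained below; see the source code in the `src/` directory for more details."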
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ac939ac5-ce64-422a-af47-c187dd0269dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from src.uitls import decompose_matrix, get_basis, reconstruct_by_vqsvd\n",
    "from src.vqsvd import VQSVDTrainer\n",
    "\n",
    "\n",
    "def run_demo(n_qubit, rank, ansatz_uv, in_mat, epoch=50, lr=1e-2, method='adam'):\n",
    "    \"\"\"Run the VQSVD process with the given parameters.\n",
    "\n",
    "    Args:\n",
    "        n_qubit: number of qubit.\n",
    "        rank: rank T.\n",
    "        ansatz_uv: U and V ansatz circuits.\n",
    "        in_mat: the input matrix.\n",
    "        epoch: the epoch of training.\n",
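    "        lr: learning rate of optimizer.\n",
    "        method: 'adam' to use the MindSpore Adam optimizer; any other value\n",
    "            uses plain gradient ascent with step size `lr`.\n",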
    "\n",
    "    Return:\n",
    "        re_mat: the reconstructed matrix.\n",
    "    \"\"\"\n",
    "    # Decompose the matrix into Pauli strings\n",
    "    mat_item = decompose_matrix(n_qubit, in_mat)\n",
    "    # Variational circuits U and V\n",
    "    ansatz_u, ansatz_v = ansatz_uv\n",
    "    # Positive weights q from step 2 of Algorithm 1 (q_1 > q_2 > ... > q_T > 0)\n",
    "    q = np.arange(rank, 0, -1)\n",
    "    # Orthonormal basis states\n",
    "    basis = get_basis(n_qubit, rank)\n",
    "    # Trainer wrapping the optimization loop\n",
    "    vqsvd = VQSVDTrainer(n_qubit, mat_item, rank, ansatz_u, ansatz_v, q, basis,\n",
    "                         lr=lr, method=method)\n",
    "    # Train the parameters\n",
    "    vqsvd.train(epoch)\n",
    "    # Reconstruct the matrix\n",
    "    re_mat = reconstruct_by_vqsvd(n_qubit, in_mat, basis, ansatz_u, ansatz_v,\n",
    "                                  weight=vqsvd.weight)\n",
    "    return re_mat"
   ]
  },
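  {
   "cell_type": "markdown",
   "id": "f1a6c8d3-sketch-rank-t-reconstruction",
   "metadata": {},
   "source": [
    "`reconstruct_by_vqsvd` lives in `src/` and is not shown here. The following is a sketch of what a rank-$T$ reconstruction amounts to. It assumes the form $M_T = \\sum_{j=1}^{T} m_j\\, U|\\psi_j\\rangle\\langle\\psi_j|V^\\dagger$ with $m_j = \\mathrm{Re}\\langle\\psi_j|U^\\dagger M V|\\psi_j\\rangle$ and $|\\psi_j\\rangle = e_j$. In this NumPy version for real matrices, the exact SVD factors stand in for perfectly trained $U$ and $V$. The helper `rank_t_reconstruction` is hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def rank_t_reconstruction(u, v, mat, rank):\n",
    "    # Real matrices only: M_T = sum_{j < T} m_j (U e_j)(V e_j)^T, where\n",
    "    # m_j = (U e_j)^T M (V e_j) is the singular value estimate for state j.\n",
    "    recon = np.zeros(mat.shape)\n",
    "    for j in range(rank):\n",
    "        u_j, v_j = u[:, j], v[:, j]\n",
    "        m_j = u_j @ mat @ v_j\n",
    "        recon += m_j * np.outer(u_j, v_j)\n",
    "    return recon\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(2)\n",
    "mat = rng.uniform(size=(8, 8))  # stand-in for the 8x8 image\n",
    "u, sigma, vt = np.linalg.svd(mat)  # exact factors in place of trained U and V\n",
    "rank = 5\n",
    "approx = rank_t_reconstruction(u, vt.T, mat, rank)\n",
    "err = np.linalg.norm(mat - approx)  # Frobenius norm\n",
    "print(err, np.sqrt(np.sum(sigma[rank:] ** 2)))\n",
    "```\n",
    "\n",
    "With exact factors the Frobenius error equals $\\sqrt{\\sum_{j>T}\\sigma_j^2}$. This is the smallest error any rank-$T$ approximation can reach (Eckart-Young theorem), which makes it the natural yardstick for the SVD and VQSVD panels in the demo below."
   ]
  },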
  {
   "cell_type": "markdown",
   "id": "1834d748-4b67-4571-9f76-82b442384660",
   "metadata": {},
   "source": [
    "The following example shows how to use the function by reconstructing an `8 x 8` matrix."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "6f979192-c108-44d6-8fe1-fe385e7bbdff",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 200/200 [01:53<00:00,  1.76it/s]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAkgAAADSCAYAAACxfD+jAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAbTUlEQVR4nO3df5xddX3n8fd7JiHhl6LMqJCgQ9WGgiuCA4Kilh8KKIprrYLlR9xds+0W6g8qRXbrinZ9tN2uRfvDNQ1IhBBEwIqUooj4AxsokwhiCLhAg0mEZEYIJIQlJPPpH+eMuffkztw7k++de+7J6/l4zCMz9577vZ+ZnPfcz/me77njiBAAAAB26Ol0AQAAAGVDgwQAAFBAgwQAAFBAgwQAAFBAgwQAAFBAgwQAAFBAg5SQ7YttL0q9bQtjhe1XpRgL2N3ZXm37pE7XAaCzaJDGYXu+7ftsb7H9uO0v2d5vosdExOci4r+0Mv5ktk3F9rW23257lu3HC/ettL255mOb7W9NZ30oN9vH2f4X20/ZfsL2j20fZfsY28/Y3qfBY35i+zzbA3kjP7Z/rbd9k+23NXnOyMfebHud7c/b7m3fd7nT83/a9lXT9XxAN7L927bXdrqO1GiQGrB9gaS/kPQJSS+UdIykV0i61fYe4zxmxvRVOGWvlzQk6bWSflZ7R0QcFhH7RMQ+kvaVtEbS16e/RJSR7RdIuknS30h6saQ5ki6R9FxE3ClpraT3FR7zGkmHSlpac/N++T52uKRbJX3D9vwmT394/pi3SvqApP+0y98QMAHbt9j+TIPbT88PmGfkX7/R9vdsb8oPHG60fUjhMRfb/re8yV9r+2v57f/X9lcbPMfhtp+z/eK8QX8+H3+T7Z/b/lvbB0xQ+3zb2/Pne9r2vbZP2/WfSuuqMgtLg1SQvxBcIun8iLglIp6PiNWS3i9pQNJZ+Xaftn2d7atsPy1pfvFo0/Y5th+1/Svbf1q709RuW3N0fa7tX9gesf3fa8Y52vYy2xttP5YHpGGjNsH39SJJjognJA1KWjHB5m+R1Cfp+sk8ByrtNyUpIpZGxPaIeDYivhMRP83vXyzpnMJjzpF0c0T8qjhYRDweEV+Q9GlJf2G76e+iiHhI0o8lvW7sNtun2b4nz8a/2H5tzX1/ks86bbL9oO0T89uvsP1nNds1PPq1fYqkiyV9IH+xubdZjaiMxZLOsu3C7WdLWhIR22wfK+k7kr4p6UBJB0v6qaQf2x6QJNvn5o85KW/yByXdVvMc77W9d4PnuCn/XS1JX4uIfZUdmPxHSS+TtHyiJknSsvz59pP095KucZMzINgZDdLO3ihptqQbam+MiM2SbpZUe0rgdEnXKdsJl9Rub/tQZTvm70k6QNlM1Jwmz32cpHmSTpT0Kdu/ld++XdLHlDUtx+b3/7dWvhnbJ9reqGxGaG7++Rck/WH+ovLWBg87V9L1EfFMK8+B3cLPJW23vdj2qXnDXetKSW+xfZAk5Q3PB5W9CEzkBkkvUbbfTyg/Mn+zpIfyr4+QdLmk/yppf0lflnSjs1PI8ySdJ+mo/MXlZEmrW/lGx0TELZI+p+wFap+IOHwyj0dX+0dl+9Sbx27I9/nTJI3N+vylpK9GxBciYlNEPBER/0PSv0r6n/k2R0n6dkQ8LP36wGBh/vkySesk/U7Nc/Qqy81OM0v5wfpKZbOow5IuaPZNRMSosmzuLenV+XPMsv1X+cH4+nwma8/8vj5np743OjuN/qOxgxcX1roWDzRqbr9S0sslfSs/sLiwWZ1lRYO0sz5JIxGxrcF9j+X3j1kWEf8YEaMR8Wxh2/dJ+lZE3BERWyV9SlKzP3x3SX5kfq+ke5WdhlBELI+IOyNiWz6b9WVlpxuaiojbImI/ZYH/XWVN2mpJfRGxX0T8oHZ723vltV/RyvjYPUTE08oa+JD0D5KG89MJL83vXyPp+8qOfqWsiZ8l6Z+aDP3L/N8X
T7DNCtvPSFqVP8ff57cvkPTliLgrn9VaLOk5ZafEt+fPf6jtmRGxeuxFCmgm/31+repnRd8v6YGIuDf/PflGNV6GcK2kt+ef3ynpHNufsD3ondfPfbXwHCdJmqnsYHy82rYrm7V683jbjMmf70OSnpf0aH7znyubEX6dpFcpe034VH7fBcpOl/dLeqmyGdRJ/cHWiDhb0i8kvSs/sPjLyTy+TGiQdjYiqc+N1xQdkN8/Zs0E4xxYe39EbJG006mGgtqF01sk7SNJtn8z7+ofz0/nfU71jdq4nJ3z3ijpTGVH8xuUrad6zPbnGzzkvZKekPSDBvdhNxYRqyJifkTMlfQaZfv4pTWbLNaOBulsSddExPNNhh2bVX1igm2OVJaFD0h6g7KjYSnbjy/Ij3Y35vv5QZIOzE/HfVTZKbwNtq+xfWDTbxLYYbGk99menX99jnbMiL5Y2evnYw0e95iyBkMRcZWk85XNYP5A2b74JzXbXinprbbn1jzH1S3k5pea+KDimDwP/1/SX0k6KyI25KcMF0j6WD7jtUnZ68kZ+eOeV/Y694p8xupHsRv/RXsapJ0tU3YU+t7aG51doXOqdpw/liburB+TNLbTK5/C3H+KNX1J0gOSXh0RL1DW1RfPjTeUv5idIum7+UzSQkl/mM8efbzBQ85VNm2824YCzUXEA8pmGV9Tc/MNyk7jHq8sP81Or0nZmooNkh5s8nwREdcqy+fY0e4aSf8r35fHPvaKiKX5Y66OiOOUNVKh7MILSXpG0l41w79soqdu4XtABUXEHcoOiN9j+5WSjpZ0dX73k5JGlTUTRXUH0hGxJCJOUrYU4/clfdb2yfl9v5D0Q2XrnfaR9B41OL3WwBxNfFBxZ/77/kWSbtSO2aZ+Zfv+8pqDilvy2yXpfys7hf0d24/YvqiFWiqLBqkgIp5Stkj7b2yfYntmvuDuWmVTj1e2ONR1kt7l7CqHPZQdybbU1DSwr6SnJW3O12H8wSQf/3rtWJR9pLIr2XaSH8Ucr9Ze2LAbsX2I7QvGjnTztUZnKjuFIEnK16xdJ+krkh6NiIb7Wf74l9o+T9lajU/mayVa8eeSPmz7ZcpO9f2+7Tc4s7ftd9re1/Y82yfYnqXsKPpZZS9oknSPpHc4u0roZcpmmsazXtKAW1hEjkoaOwV2lrK1ROulX+/ry5QtWyh6v7JTwXXyGZmvK1vIXXtgMTbz+juS/i0ilk9UUL4vvkvSj5oVn6+d/QNJZ+dr9kaUZeGwmoOKF+YLupWvpbogIn5D0rslfXzs4gZlZzV2qwMLQt9Afs70YmVTk09LukvZ0eqJEfFci2OsVDa1eo2y2aTNyo6UW3p8wR8rW7i3SdmLwtcm+fjXK1vHYUmHSFo5znZnK1tXxVoNFG1Sdnrrrnw90J3K3iqiuFB0sbIZm/GOgjfmj79P0jsk/W5EXN5qERFxn7Ij7k/kDdiHJf2tsiP6hyTNzzedpayZGlF26volkj6Z33elsjV+q5VdhTRRnsbWmPzK9kRXfqKavqpsXdCHtfOB40WSzrX9R3lT/qJ80fKblZ22Grvkfqxp77F9qqTDlL2mjLle2aLmSxo8x6/ZnpFfuLNUWXPSaInETvKr4RZJ+lR+IPIPkv7a9kvyceeMzWg5uyr0VflrxVPK1vLVHlh80Havsys8J1oHu17Sb7RSX6lFBB/T8KFsDcU2SQd3uhY++OCDDz5a+1A2G/SkpFkN7jsuv3+zslmTNZLeUHP/e5W9NcWTyg6275M0v8E4V+SvDwcWbv+0snVBm5WdGv5/yi5SmDNBvfMl3VG4ba6yg/PXKrtK+3OSHslrWiXpj/LtPqbswOEZZWdM/rRmjEFlB9eblB1kLJX0Z/l9vy1pbc22pytbqL1R0h93+v9wqh/Ovxm0ge13KVuzZEn/R9kR+JHBDx0AKsXZe3DdLumDEfHtTteDXccptvY6XdnVBr9U9h4UZ9AcAUD1RPamqe+R9B/GuQoaXYYZ
JAAAgAJmkAAAAApokAAAAAracp60r68vBgYG2jE0Sib1KVrv9Lchd83q1as1MjKSdtApIBO7DzLRXOo8jI62+jZandPTk34+ouz7Wju0Y1nQihUrRiKiv3h7WxqkgYEBDQ2N+x5xaFHqHaEdv0RSjzlz5syk4w0ODiYdb6oGBgZ09913JxuvHb/IuuGX7e6YiRkz0v6aPuqoo5KONxWpXyO2bNmSbKx2mT17dvONJin1vtYNTVw7GqSZM2c+2uh2TrEBAAAU0CABAAAU0CABAAAU0CABAAAUtNQg5X/V/kHbD9m+qN1FAWVHJoB6ZAJV07RBst0r6e8knSrpUEln2j603YUBZUUmgHpkAlXUygzS0ZIeiohHImKrpGuU/Y0xYHdFJoB6ZAKV00qDNEfSmpqv1+a3AbsrMgHUIxOonGSLtG0vsD1ke2h4eDjVsEDXIhPADuQB3aaVBmmdpINqvp6b31YnIhZGxGBEDPb37/SO3UCVkAmgXtNMkAd0m1YapLslvdr2wbb3kHSGpBvbWxZQamQCqEcmUDlN/8hPRGyzfZ6kb0vqlXR5RKxse2VASZEJoB6ZQBW19FcQI+JmSTe3uRaga5AJoB6ZQNXwTtoAAAAFNEgAAAAFNEgAAAAFNEgAAAAFLS3SrqLt27d3uoSmbHe6hKZ6e3s7XULXSPn/GRHJxhozOjqafMzUuiETPT1pjzu74XueipT7Wzfsu+3IbOrXsdT7bjvMmDF9bUv5fxoAAADTjAYJAACggAYJAACggAYJAACggAYJAACggAYJAACggAYJAACgoGmDZPty2xts/2w6CgLKjkwA9cgEqqiVGaQrJJ3S5jqAbnKFyARQ6wqRCVRM0wYpIn4o6YlpqAXoCmQCqEcmUEXJ1iDZXmB7yPbQ8PBwqmGBrkUmgB3IA7pNsgYpIhZGxGBEDPb396caFuhaZALYgTyg23AVGwAAQAENEgAAQEErl/kvlbRM0jzba23/5/aXBZQXmQDqkQlU0YxmG0TEmdNRCNAtyARQj0ygijjFBgAAUECDBAAAUECDBAAAUECDBAAAUNB0kfZUjY6OJhvrqaeeSjbWmNtuuy3peFdffXXS8STp9ttvTzrexo0bk44nSVu3bk06Xk9PdXv2lJlox/9l6kwsXbo06XjS7pmJ3t7epOOVhe1kY23evDnZWGO++93vJh1vyZIlSceTpKGhoaTjjYyMJB1PkrZt25Z8zOlS3VcjAACAKaJBAgAAKKBBAgAAKKBBAgAAKKBBAgAAKKBBAgAAKGjlj9UeZPt22/fbXmn7I9NRGFBWZAKoRyZQRa28D9I2SRdExArb+0pabvvWiLi/zbUBZUUmgHpkApXTdAYpIh6LiBX555skrZI0p92FAWVFJoB6ZAJVNKk1SLYHJB0h6a62VAN0GTIB1CMTqIqWGyTb+0i6XtJHI+LpBvcvsD1ke2h4eDhljUApkQmg3kSZIA/oNi01SLZnKtvpl0TEDY22iYiFETEYEYP9/f0pawRKh0wA9Zplgjyg27RyFZslXSZpVUR8vv0lAeVGJoB6ZAJV1MoM0psknS3pBNv35B/vaHNdQJmRCaAemUDlNL3MPyLukORpqAXoCmQCqEcmUEW8kzYAAEABDRIAAEABDRIAAEABDRIAAEBBK3+LrePWrVuXfMxZs2YlHW/p0qVJx5OkZ599Nul4n/nMZ5KOJ0k9PfTYndCOTMyePTvpeO3IxJYtW5KO99nPfjbpeBKZaEVEaOvWrcnGe/LJJ5ONNWb//fdPOt43v/nNpONJ0qZNm5KOd8kllyQdT5Kee+65pONNZ75IMgAAQAENEgAAQAENEgAAQAENEgAAQAENEgAAQAENEgAAQAENEgAAQEHTBsn2bNv/avte2yttp3+jBKCLkAmgHplAFbXyRpHPSTohIjbbninpDtv/HBF3trk2oKzIBFCPTKBymjZIERGSNudfzsw/op1FAWVGJoB6ZAJV1NIaJNu9tu+RtEHS
rRFxV4NtFtgesj00PDycuEygXMgEUK9ZJsgDuk1LDVJEbI+I10maK+lo269psM3CiBiMiMH+/v7EZQLlQiaAes0yQR7QbSZ1FVtEbJR0u6RT2lIN0GXIBFCPTKAqWrmKrd/2fvnne0p6m6QH2lwXUFpkAqhHJlBFrVzFdoCkxbZ7lTVU10bETe0tCyg1MgHUIxOonFauYvuppCOmoRagK5AJoB6ZQBXxTtoAAAAFNEgAAAAFNEgAAAAFNEgAAAAFrVzF1nGHHXZY8jHnzZuXdLyenvS95jHHHJN0vPvvvz/peOicdmTikEMOSTpeOzJx7LHHJh2vHZnI/uoGJhIRGh0dTTZe6t/nkvTKV74y6Xgpv98xJ510UtLxli1blnQ8Kf333Y7fK+M+17Q9EwAAQJegQQIAACigQQIAACigQQIAACigQQIAACigQQIAAChouUGy3Wv7J7b5A4SAyARQRCZQJZOZQfqIpFXtKgToQmQCqEcmUBktNUi250p6p6RF7S0H6A5kAqhHJlA1rc4gXSrpQknjviWm7QW2h2wPDQ8Pp6gNKLNLRSaAWpdqgkzU5mFkZGRaCwOmommDZPs0SRsiYvlE20XEwogYjIjB/v7+ZAUCZUMmgHqtZKI2D319fdNYHTA1rcwgvUnSu22vlnSNpBNsX9XWqoByIxNAPTKBymnaIEXEJyNibkQMSDpD0vci4qy2VwaUFJkA6pEJVBHvgwQAAFAwYzIbR8T3JX2/LZUAXYhMAPXIBKqCGSQAAIACGiQAAIACGiQAAIACGiQAAICCSS3SnoyennL3Xr29vUnH27BhQ9LxJOnkk09OOp7tpONJ0vbt25OOV/b9Zlek/Pm34/8yIpKO1453D++GTKT+OVaRbc2Y0baXnyRS7xvr169POp4kHX/88UnHa8f/ybZt25KON535qu6rEQAAwBTRIAEAABTQIAEAABTQIAEAABTQIAEAABTQIAEAABS0dE2f7dWSNknaLmlbRAy2syig7MgEUI9MoGom86YHx0fESNsqAboPmQDqkQlUBqfYAAAAClptkELSd2wvt72gnQUBXYJMAPXIBCql1VNsx0XEOtsvkXSr7Qci4oe1G+SBWCBJL3/5yxOXCZQOmQDqTZgJ8oBu09IMUkSsy//dIOkbko5usM3CiBiMiMH+/v60VQIlQyaAes0yUZuHvr6+TpQITErTBsn23rb3Hftc0tsl/azdhQFlRSaAemQCVdTKKbaXSvpG/peNZ0i6OiJuaWtVQLmRCaAemUDlNG2QIuIRSYdPQy1AVyATQD0ygSriMn8AAIACGiQAAIACGiQAAIACGiQAAIACGiQAAICCyfyx2krZvn170vHOP//8pONJ0mWXXZZ8zNR6e3s7XULXyC+BTiIiko01ZnR0NOl47cjEokWLko+ZWk8Px53TrR15SD3mhRdemHQ8SfriF7+YdLzUvwOk9HmYMWP62haSDAAAUECDBAAAUECDBAAAUECDBAAAUECDBAAAUECDBAAAUNBSg2R7P9vX2X7A9irbx7a7MKDMyARQj0ygalp9Q4EvSLolIt5new9Je7WxJqAbkAmgHplApTRtkGy/UNJbJM2XpIjYKmlre8sCyotMAPXIBKqolVNsB0salvQV2z+xvcj23sWNbC+wPWR7aHh4OHmhQImQCaBe00zU5mFkZKQzVQKT0EqDNEPSkZK+FBFHSHpG0kXFjSJiYUQMRsRgf39/4jKBUiETQL2mmajNQ19fXydqBCallQZpraS1EXFX/vV1yoIA7K7IBFCPTKBymjZIEfG4pDW25+U3nSjp/rZWBZQYmQDqkQlUUatXsZ0vaUl+ZcIjkj7UvpKArkAmgHpkApXSUoMUEfdIGmxvKUD3IBNAPTKBquGdtAEAAApokAAAAApokAAAAApokAAAAApavYqtcmwnHW/16tVJx5Ok2bNnJx0v9feMzumG/8t2ZGLPPfdMOl43/ByrKuXPvre3N9lYY0ZHR5OO
9/DDDycdT0r/GtEOPT1p52GmM7PMIAEAABTQIAEAABTQIAEAABTQIAEAABTQIAEAABTQIAEAABQ0bZBsz7N9T83H07Y/Og21AaVEJoB6ZAJV1PR9kCLiQUmvkyTbvZLWSfpGe8sCyotMAPXIBKposqfYTpT0cEQ82o5igC5EJoB6ZAKVMNkG6QxJS9tRCNClyARQj0ygElpukGzvIendkr4+zv0LbA/ZHhoeHk5VH1BaZAKoN1EmavMwMjIy/cUBkzSZGaRTJa2IiPWN7oyIhRExGBGD/f39aaoDyo1MAPXGzURtHvr6+jpQGjA5k2mQzhTTpkAtMgHUIxOojJYaJNt7S3qbpBvaWw7QHcgEUI9MoGqaXuYvSRHxjKT921wL0DXIBFCPTKBqeCdtAACAAhokAACAAhokAACAAhokAACAAhokAACAAkdE+kHtYUmt/B2ePkllf0vVstdY9vqkztb4iojo+Ls0VigTZa9PosZmOp6JSeRBKv//Z9nrk6ixmYaZaEuD1CrbQxEx2LECWlD2Gsten9QdNZZF2X9WZa9PosaqKfvPquz1SdQ4VZxiAwAAKKBBAgAAKOh0g7Sww8/firLXWPb6pO6osSzK/rMqe30SNVZN2X9WZa9PosYp6egaJAAAgDLq9AwSAABA6XSkQbJ9iu0HbT9k+6JO1DAR2wfZvt32/bZX2v5Ip2saj+1e2z+xfVOna2nE9n62r7P9gO1Vto/tdE1lRCbSIRPVQCbSIRNTM+2n2Gz3Svq5pLdJWivpbklnRsT901rIBGwfIOmAiFhhe19JyyW9p0w1jrH9cUmDkl4QEad1up4i24sl/SgiFtneQ9JeEbGxw2WVCplIi0x0PzKRFpmYmk7MIB0t6aGIeCQitkq6RtLpHahjXBHxWESsyD/fJGmVpDmdrWpntudKeqekRZ2upRHbL5T0FkmXSVJEbC3DTl9CZCIRMlEZZCIRMjF1nWiQ5khaU/P1WpVwpxpje0DSEZLu6nApjVwq6UJJox2uYzwHSxqW9JV8eneR7b07XVQJkYl0LhWZqAIykc6lIhNTwiLtCdjeR9L1kj4aEU93up5atk+TtCEilne6lgnMkHSkpC9FxBGSnpFUurUEaB2Z2GVkomLIxC4rbSY60SCtk3RQzddz89tKxfZMZTv9koi4odP1NPAmSe+2vVrZ9PMJtq/qbEk7WStpbUSMHVVdpywIqEcm0iAT1UEm0iATu6ATDdLdkl5t++B8MdYZkm7sQB3jsm1l50NXRcTnO11PIxHxyYiYGxEDyn6G34uIszpcVp2IeFzSGtvz8ptOlFS6BYwlQCYSIBOVQiYSIBO7ZsZ0P2FEbLN9nqRvS+qVdHlErJzuOpp4k6SzJd1n+578tosj4ubOldS1zpe0JP8l94ikD3W4ntIhE7sdMtEEmdjtlDITvJM2AABAAYu0AQAACmiQAAAACmiQAAAACmiQAAAACmiQAAAACmiQAAAACmiQAAAACmiQAAAACv4djNH3UfV5QlEAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 720x216 with 3 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from src.uitls import reconstruct_by_svd, get_ansatz\n",
    "\n",
    "\n",
    "def demo():\n",
    "    n_qubit = 3           # number of qubits\n",
    "    depth = 20            # circuit depth passed to get_ansatz\n",
    "    rank = 5              # rank T used for the reconstruction\n",
    "    max_epoch = 200       # maximum number of training epochs\n",
    "    lr = 0.01             # learning rate\n",
    "    # Load the 8x8 image\n",
    "    in_mat = plt.imread('images/digit7_8x8.png')\n",
    "\n",
    "    # Reconstruct the image with classical SVD\n",
    "    re_mat_svd = reconstruct_by_svd(in_mat, rank)\n",
    "    # Reconstruct the image with VQSVD\n",
    "    ansatz_uv = get_ansatz(n_qubit, depth)\n",
    "    re_mat = run_demo(n_qubit, rank, ansatz_uv, in_mat,\n",
    "                      epoch=max_epoch, lr=lr)\n",
    "    # Plot\n",
    "    plt.figure(figsize=(10, 3))\n",
    "    plt.subplot(1, 3, 1)\n",
    "    plt.imshow(in_mat, cmap='gray')\n",
    "    plt.title(\"Original #7\")\n",
    "    plt.subplot(1, 3, 2)\n",
    "    plt.imshow(re_mat_svd, cmap='gray')\n",
    "    plt.title(\"SVD Result\")\n",
    "    plt.subplot(1, 3, 3)\n",
    "    plt.imshow(re_mat, cmap='gray')\n",
    "    plt.title(\"VQSVD Result\")\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "demo()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a69566fb-97bf-480f-a544-ae6ae0432973",
   "metadata": {},
   "source": [
    "## Reproducing the figures of the paper\n",
    "\n",
    "The code that reproduces the other figures of the paper is as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cfe913f-8799-4fa2-a723-8b4ee84506c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.demo import plot_figure4\n",
    "# from src.demo import plot_figure5, plot_figure7, plot_figure7_light\n",
    "\n",
    "\n",
    "# Reproduce Fig. 4 of the paper\n",
    "plot_figure4()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "726d0c87-b5e0-47a7-9de8-5ba0123b3bed",
   "metadata": {},
   "source": [
    "- Running `plot_figure4()` reproduces Fig. 4 of the paper:\n",
    "\n",
    "![](images/figure4_bak.png)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa164181-abe7-4c07-9911-ef1c9c741fb7",
   "metadata": {},
   "source": [
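    "- Running `plot_figure5(n_qubit=3)` reconstructs an `8 x 8` matrix with VQSVD, which takes about 6 minutes on an 8-vCPU, 32 GB machine. Running `plot_figure5(n_qubit=5)` reconstructs a `32 x 32` matrix and is much slower, taking about 3 hours on the same machine. The result of `plot_figure5(n_qubit=3)` is shown below; VQSVD reconstructs the original image well.\n",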
    "\n",
    "![](images/figure5_8x8_bak.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51268c42-c7bf-412f-95b5-fa6407b5731b",
   "metadata": {},
   "source": [
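    "- Running `plot_figure7_light()` gives the result below: with the same number of parameters, circuit (a) performs best. `plot_figure7_light` optimizes with SciPy's BFGS optimizer, while `plot_figure7` uses MindSpore's Adam optimizer; the former converges faster.\n",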
    "\n",
    "![](images/figure7_bak.png)\n"
   ]
  },
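  {
   "cell_type": "markdown",
   "id": "2b7d5e90-sketch-bfgs",
   "metadata": {},
   "source": [
    "`plot_figure7_light` delegates the optimization to SciPy's BFGS. The following is a minimal sketch of that pattern on a toy $2\\times 2$ problem, with $U = R_y(\\alpha)$, $V = R_y(\\beta)$ and $|\\psi_j\\rangle = e_j$. The names are illustrative, and this is not the code in `src.demo`. Because `scipy.optimize.minimize` minimizes, the objective is the negated loss, the same trick as negating the gradient before Adam.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from scipy.optimize import minimize\n",
    "\n",
    "\n",
    "def ry(theta):\n",
    "    # Real 2x2 matrix of R_y(theta).\n",
    "    c, s = np.cos(theta / 2), np.sin(theta / 2)\n",
    "    return np.array([[c, -s], [s, c]])\n",
    "\n",
    "\n",
    "def neg_loss(params, mat, q):\n",
    "    # minimize() minimizes, so return -L(alpha, beta) in order to maximize L.\n",
    "    alpha, beta = params\n",
    "    return -float(np.sum(q * np.diag(ry(alpha).T @ mat @ ry(beta))))\n",
    "\n",
    "\n",
    "mat = np.array([[3.0, 1.0], [1.0, 2.0]])\n",
    "q = np.array([2.0, 1.0])\n",
    "res = minimize(neg_loss, x0=np.array([0.3, -0.7]), args=(mat, q), method='BFGS')\n",
    "sigma = np.linalg.svd(mat, compute_uv=False)\n",
    "print(-res.fun, q @ sigma, res.nit)\n",
    "```\n",
    "\n",
    "No `jac` is passed, so SciPy estimates the gradient by finite differences. BFGS then builds a quasi-Newton approximation of the curvature from successive gradients, which typically lets it converge in far fewer iterations than first-order updates."
   ]
  },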
  {
   "cell_type": "markdown",
   "id": "756fdd5d-11fb-4846-a47f-2346a683b162",
   "metadata": {},
   "source": [
    "## Further improvements\n",
    "\n",
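    "Reproducing Figure 5 of the paper is slow when the image is `32x32`. Two changes speed it up. First, the BFGS optimizer accelerates convergence. Second, expectation values are computed directly from the quantum state instead of decomposing the matrix into Hamiltonian terms and evaluating them one by one. Computing expectation values directly from the state only demonstrates the overall feasibility of the algorithm on a simulator; on real hardware the quantum state is not accessible. With these changes the algorithm runs about 6 times faster. The entry point is `plot_figure5_light`, which is built on `scipy.optimize.minimize`; see the function `utils.run_light` for details."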
   ]
  },
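  {
   "cell_type": "markdown",
   "id": "9c3f2a71-sketch-direct-expectation",
   "metadata": {},
   "source": [
    "Part of the speed-up from working with the state comes from skipping the $4^n$-term sum over Pauli strings. The NumPy sketch below is illustrative only: random normalized vectors stand in for the simulated states $U|\\psi_j\\rangle$ and $V|\\psi_j\\rangle$. It shows that both routes give the same value of $\\langle\\psi_j|U^\\dagger M V|\\psi_j\\rangle$. The term-by-term route, however, needs $4^3 = 64$ expectation values for 3 qubits and 1024 for the 5-qubit ($32\\times 32$) case.\n",
    "\n",
    "```python\n",
    "import itertools as it\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# Single-qubit Pauli basis (I, X, Y, Z).\n",
    "PAULIS = [np.eye(2), np.array([[0, 1], [1, 0]]),\n",
    "          np.array([[0, -1j], [1j, 0]]), np.diag([1, -1])]\n",
    "\n",
    "\n",
    "def pauli_string(index):\n",
    "    # Tensor product of the Pauli matrices selected by the tuple `index`.\n",
    "    mat = np.eye(1)\n",
    "    for j in index:\n",
    "        mat = np.kron(mat, PAULIS[j])\n",
    "    return mat\n",
    "\n",
    "\n",
    "n_qubit = 3\n",
    "dim = 2 ** n_qubit\n",
    "rng = np.random.default_rng(3)\n",
    "mat = rng.normal(size=(dim, dim))\n",
    "# Random normalized vectors stand in for the simulated states U|psi_j> and V|psi_j>.\n",
    "left = rng.normal(size=dim)\n",
    "left /= np.linalg.norm(left)\n",
    "right = rng.normal(size=dim)\n",
    "right /= np.linalg.norm(right)\n",
    "\n",
    "# Term-by-term route: one expectation value per Pauli string, 4**n_qubit in total.\n",
    "terms = list(it.product(range(4), repeat=n_qubit))\n",
    "per_term = sum(np.trace(pauli_string(k).conj().T @ mat) / dim\n",
    "               * (left.conj() @ pauli_string(k) @ right) for k in terms)\n",
    "\n",
    "# Direct route (simulator only): a single inner product with M itself.\n",
    "direct = left.conj() @ mat @ right\n",
    "print(len(terms), np.isclose(per_term, direct))\n",
    "```"
   ]
  },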
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22907fcd-e902-44b9-b40b-197662b326e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.demo import plot_figure5_light\n",
    "\n",
    "\n",
    "plot_figure5_light()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57ce5832-2d30-43d2-8f4f-4ed02a5fdbab",
   "metadata": {},
   "source": [
    "The reproduced result is shown below; VQSVD successfully reconstructs the image.\n",
    "\n",
    "![](images/figure5_bak.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8adc560-6f34-4052-b7d4-03dae42bc14d",
   "metadata": {},
   "source": [
    "## References\n",
    "\n",
    "[1] Wang, X., Song, Z., & Wang, Y. (2021). Variational Quantum Singular Value Decomposition. Quantum, 5, 483."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mindspore",
   "language": "python",
   "name": "mindspore"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
